{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AI Music Compositor using LangGraph\n",
    "\n",
    "## Overview\n",
    "This tutorial demonstrates how to build an AI-powered music composition system using LangGraph, a framework for creating workflows with language models. The system generates musical compositions based on user input, leveraging various components to create melody, harmony, rhythm, and style adaptations.\n",
    "\n",
    "## Motivation\n",
    "Creating music programmatically is a fascinating intersection of artificial intelligence and artistic expression. This project aims to explore how language models and graph-based workflows can be used to generate coherent musical pieces, providing a unique approach to AI-assisted music composition.\n",
    "\n",
    "## Key Components\n",
    "1. State Management: Utilizes a `MusicState` class to manage the workflow's state.\n",
    "2. Language Model: Employs ChatOpenAI (GPT-4) for generating musical components.\n",
    "3. Musical Functions:\n",
    "   - Melody Generator\n",
    "   - Harmony Creator\n",
    "   - Rhythm Analyzer\n",
    "   - Style Adapter\n",
    "4. MIDI Conversion: Transforms the composition into a playable MIDI file.\n",
    "5. LangGraph Workflow: Orchestrates the composition process using a state graph.\n",
    "6. Playback Functionality: Allows for immediate playback of the generated composition.\n",
    "\n",
    "## Method\n",
    "1. The workflow begins by generating a melody based on user input.\n",
    "2. It then creates harmony to complement the melody.\n",
    "3. A rhythm is analyzed and suggested for the melody and harmony.\n",
    "4. The composition is adapted to the specified musical style.\n",
    "5. The final composition is converted to MIDI format.\n",
    "6. The generated MIDI file can be played back using pygame.\n",
    "\n",
    "The entire process is orchestrated using LangGraph, which manages the flow of information between different components and ensures that each step builds upon the previous ones.\n",
    "\n",
    "## Conclusion\n",
    "This AI Music Compositor demonstrates the potential of combining language models with structured workflows to create musical compositions. By breaking down the composition process into discrete steps and leveraging the power of AI, we can generate unique musical pieces based on simple user inputs. This approach opens up new possibilities for AI-assisted creativity in music production and composition."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"text-align: center;\">\n",
    "\n",
    "<img src=\"../images/music_composer_agent_langgraph.svg\" alt=\"tts poem generator agent langgraph\" style=\"width:50%; height:auto;\">\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports\n",
    "\n",
    "Import all necessary modules and libraries for the AI Music Collaborator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pygame 2.6.0 (SDL 2.28.4, Python 3.12.0)\n",
      "Hello from the pygame community. https://www.pygame.org/contribute.html\n"
     ]
    }
   ],
   "source": [
    "# Import required libraries\n",
    "from typing import Dict, TypedDict\n",
    "from langgraph.graph import StateGraph, END\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain.prompts import ChatPromptTemplate\n",
    "import music21\n",
    "import pygame\n",
    "import tempfile\n",
    "import os\n",
    "import random\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# Load environment variables and set OpenAI API key\n",
    "load_dotenv()\n",
    "os.environ[\"OPENAI_API_KEY\"] = os.getenv('OPENAI_API_KEY')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## State Definition\n",
    "\n",
    "Define the MusicState class to hold the workflow's state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MusicState(TypedDict):\n",
    "    \"\"\"Define the structure of the state for the music generation workflow.\"\"\"\n",
    "    musician_input: str  # User's input describing the desired music\n",
    "    melody: str          # Generated melody\n",
    "    harmony: str         # Generated harmony\n",
    "    rhythm: str          # Generated rhythm\n",
    "    style: str           # Desired musical style\n",
    "    composition: str     # Complete musical composition\n",
    "    midi_file: str       # Path to the generated MIDI file"
   ]
  },
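  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MusicState` is a `TypedDict`, so it adds type hints only: Python does not check the keys at runtime, and the state is an ordinary dictionary. That is why the workflow can be invoked with only `musician_input` and `style` and have the remaining keys filled in as the nodes run. The sketch below illustrates this with the standard library alone; `DemoState` is a hypothetical two-key stand-in, declared with `total=False` so that type checkers also accept a partially filled state.\n",
    "\n",
    "```python\n",
    "from typing import TypedDict\n",
    "\n",
    "\n",
    "class DemoState(TypedDict, total=False):\n",
    "    # Hypothetical stand-in for MusicState; total=False marks every key optional\n",
    "    musician_input: str\n",
    "    melody: str\n",
    "\n",
    "\n",
    "# Start with only the key the caller supplies\n",
    "state: DemoState = {'musician_input': 'a happy piano piece in C major'}\n",
    "print(type(state) is dict)  # the state is a plain dict at runtime\n",
    "\n",
    "# A later step fills in its own key\n",
    "state['melody'] = 'C4 E4 G4'\n",
    "print(sorted(state))\n",
    "```\n"
   ]
  },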
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LLM Initialization\n",
    "\n",
    "Initialize the Language Model (LLM) for generating musical components."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the ChatOpenAI model\n",
    "llm = ChatOpenAI(model=\"gpt-4o-mini\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Component Functions\n",
    "\n",
    "Define the component functions for melody generation, harmony creation, rhythm analysis, style adaptation, and MIDI conversion."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def melody_generator(state: MusicState) -> Dict:\n",
    "    \"\"\"Generate a melody based on the user's input.\"\"\"\n",
    "    prompt = ChatPromptTemplate.from_template(\n",
    "        \"Generate a melody based on this input: {input}. Represent it as a string of notes in music21 format.\"\n",
    "    )\n",
    "    chain = prompt | llm\n",
    "    melody = chain.invoke({\"input\": state[\"musician_input\"]})\n",
    "    return {\"melody\": melody.content}\n",
    "\n",
    "def harmony_creator(state: MusicState) -> Dict:\n",
    "    \"\"\"Create harmony for the generated melody.\"\"\"\n",
    "    prompt = ChatPromptTemplate.from_template(\n",
    "        \"Create harmony for this melody: {melody}. Represent it as a string of chords in music21 format.\"\n",
    "    )\n",
    "    chain = prompt | llm\n",
    "    harmony = chain.invoke({\"melody\": state[\"melody\"]})\n",
    "    return {\"harmony\": harmony.content}\n",
    "\n",
    "def rhythm_analyzer(state: MusicState) -> Dict:\n",
    "    \"\"\"Analyze and suggest a rhythm for the melody and harmony.\"\"\"\n",
    "    prompt = ChatPromptTemplate.from_template(\n",
    "        \"Analyze and suggest a rhythm for this melody and harmony: {melody}, {harmony}. Represent it as a string of durations in music21 format.\"\n",
    "    )\n",
    "    chain = prompt | llm\n",
    "    rhythm = chain.invoke({\"melody\": state[\"melody\"], \"harmony\": state[\"harmony\"]})\n",
    "    return {\"rhythm\": rhythm.content}\n",
    "\n",
    "def style_adapter(state: MusicState) -> Dict:\n",
    "    \"\"\"Adapt the composition to the specified musical style.\"\"\"\n",
    "    prompt = ChatPromptTemplate.from_template(\n",
    "        \"Adapt this composition to the {style} style: Melody: {melody}, Harmony: {harmony}, Rhythm: {rhythm}. Provide the result in music21 format.\"\n",
    "    )\n",
    "    chain = prompt | llm\n",
    "    adapted = chain.invoke({\n",
    "        \"style\": state[\"style\"],\n",
    "        \"melody\": state[\"melody\"],\n",
    "        \"harmony\": state[\"harmony\"],\n",
    "        \"rhythm\": state[\"rhythm\"]\n",
    "    })\n",
    "    return {\"composition\": adapted.content}\n",
    "\n",
    "def midi_converter(state: MusicState) -> Dict:\n",
    "    \"\"\"Convert the composition to MIDI format and save it as a file.\"\"\"\n",
    "    # Create a new stream\n",
    "    piece = music21.stream.Score()\n",
    "\n",
    "    # Add the composition description to the stream as a text expression\n",
    "    description = music21.expressions.TextExpression(state[\"composition\"])\n",
    "    piece.append(description)\n",
    "\n",
    "    # Define a wide variety of scales and chords\n",
    "    scales = {\n",
    "        'C major': ['C', 'D', 'E', 'F', 'G', 'A', 'B'],\n",
    "        'C minor': ['C', 'D', 'Eb', 'F', 'G', 'Ab', 'Bb'],\n",
    "        'C harmonic minor': ['C', 'D', 'Eb', 'F', 'G', 'Ab', 'B'],\n",
    "        'C melodic minor': ['C', 'D', 'Eb', 'F', 'G', 'A', 'B'],\n",
    "        'C dorian': ['C', 'D', 'Eb', 'F', 'G', 'A', 'Bb'],\n",
    "        'C phrygian': ['C', 'Db', 'Eb', 'F', 'G', 'Ab', 'Bb'],\n",
    "        'C lydian': ['C', 'D', 'E', 'F#', 'G', 'A', 'B'],\n",
    "        'C mixolydian': ['C', 'D', 'E', 'F', 'G', 'A', 'Bb'],\n",
    "        'C locrian': ['C', 'Db', 'Eb', 'F', 'Gb', 'Ab', 'Bb'],\n",
    "        'C whole tone': ['C', 'D', 'E', 'F#', 'G#', 'A#'],\n",
    "        'C diminished': ['C', 'D', 'Eb', 'F', 'Gb', 'Ab', 'A', 'B'],\n",
    "    }\n",
    "\n",
    "    chords = {\n",
    "        'C major': ['C4', 'E4', 'G4'],\n",
    "        'C minor': ['C4', 'Eb4', 'G4'],\n",
    "        'C diminished': ['C4', 'Eb4', 'Gb4'],\n",
    "        'C augmented': ['C4', 'E4', 'G#4'],\n",
    "        'C dominant 7th': ['C4', 'E4', 'G4', 'Bb4'],\n",
    "        'C major 7th': ['C4', 'E4', 'G4', 'B4'],\n",
    "        'C minor 7th': ['C4', 'Eb4', 'G4', 'Bb4'],\n",
    "        'C half-diminished 7th': ['C4', 'Eb4', 'Gb4', 'Bb4'],\n",
    "        'C fully diminished 7th': ['C4', 'Eb4', 'Gb4', 'A4'],\n",
    "    }\n",
    "\n",
    "    def create_melody(scale_name, duration):\n",
    "        \"\"\"Create a melody based on a given scale.\"\"\"\n",
    "        melody = music21.stream.Part()\n",
    "        scale = scales[scale_name]\n",
    "        for _ in range(duration):\n",
    "            note = music21.note.Note(random.choice(scale) + '4')\n",
    "            note.quarterLength = 1\n",
    "            melody.append(note)\n",
    "        return melody\n",
    "\n",
    "    def create_chord_progression(duration):\n",
    "        \"\"\"Create a chord progression.\"\"\"\n",
    "        harmony = music21.stream.Part()\n",
    "        for _ in range(duration):\n",
    "            chord_name = random.choice(list(chords.keys()))\n",
    "            chord = music21.chord.Chord(chords[chord_name])\n",
    "            chord.quarterLength = 1\n",
    "            harmony.append(chord)\n",
    "        return harmony\n",
    "\n",
    "    # Parse the user input to determine scale and style\n",
    "    user_input = state['musician_input'].lower()\n",
    "    if 'minor' in user_input:\n",
    "        scale_name = 'C minor'\n",
    "    elif 'major' in user_input:\n",
    "        scale_name = 'C major'\n",
    "    else:\n",
    "        scale_name = random.choice(list(scales.keys()))\n",
    "\n",
    "    # Create a 7-second piece (7 beats at 60 BPM)\n",
    "    melody = create_melody(scale_name, 7)\n",
    "    harmony = create_chord_progression(7)\n",
    "\n",
    "    # Add a final whole note to make it exactly 8 beats (7 seconds at 60 BPM)\n",
    "    final_note = music21.note.Note(scales[scale_name][0] + '4')\n",
    "    final_note.quarterLength = 1\n",
    "    melody.append(final_note)\n",
    "    \n",
    "    final_chord = music21.chord.Chord(chords[scale_name.split()[0] + ' ' + scale_name.split()[1]])\n",
    "    final_chord.quarterLength = 1\n",
    "    harmony.append(final_chord)\n",
    "\n",
    "    # Add the melody and harmony to the piece\n",
    "    piece.append(melody)\n",
    "    piece.append(harmony)\n",
    "\n",
    "    # Set the tempo to 60 BPM\n",
    "    piece.insert(0, music21.tempo.MetronomeMark(number=60))\n",
    "\n",
    "    # Create a temporary MIDI file\n",
    "    with tempfile.NamedTemporaryFile(delete=False, suffix='.mid') as temp_midi:\n",
    "        piece.write('midi', temp_midi.name)\n",
    "    \n",
    "    return {\"midi_file\": temp_midi.name}\n"
   ]
  },
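  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prompts above ask the model for notes as text, and `midi_converter` keeps that text only as a `TextExpression` annotation. As a minimal sketch of how such a string could be checked before further use, the helper below extracts pitch tokens with a regular expression. It assumes notes look like `C4`, `F#5`, or `Eb3` (letter, optional accidental, one octave digit); real model output may not follow this shape, and `extract_pitch_tokens` is a hypothetical helper, not part of music21.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Letter A-G, optional accidental ('#', 'b', or music21-style '-'), one octave digit\n",
    "PITCH_TOKEN = re.compile(r'\\b([A-G])([#b-]?)(\\d)\\b')\n",
    "\n",
    "\n",
    "def extract_pitch_tokens(text):\n",
    "    # Return pitch tokens such as 'C4' or 'Eb3' in order of appearance\n",
    "    return [''.join(m.groups()) for m in PITCH_TOKEN.finditer(text)]\n",
    "\n",
    "\n",
    "llm_text = 'Melody: C4 E4 G4, then F#5 and Eb3. Tempo 120.'\n",
    "print(extract_pitch_tokens(llm_text))\n",
    "```\n",
    "\n",
    "An empty result is a cheap signal that the model's answer did not contain notes in the assumed shape."
   ]
  },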
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph Construction\n",
    "\n",
    "Construct the LangGraph workflow for the AI Music Collaborator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the StateGraph\n",
    "workflow = StateGraph(MusicState)\n",
    "\n",
    "# Add nodes to the graph\n",
    "workflow.add_node(\"melody_generator\", melody_generator)\n",
    "workflow.add_node(\"harmony_creator\", harmony_creator)\n",
    "workflow.add_node(\"rhythm_analyzer\", rhythm_analyzer)\n",
    "workflow.add_node(\"style_adapter\", style_adapter)\n",
    "workflow.add_node(\"midi_converter\", midi_converter)\n",
    "\n",
    "# Set the entry point of the graph\n",
    "workflow.set_entry_point(\"melody_generator\")\n",
    "\n",
    "# Add edges to connect the nodes\n",
    "workflow.add_edge(\"melody_generator\", \"harmony_creator\")\n",
    "workflow.add_edge(\"harmony_creator\", \"rhythm_analyzer\")\n",
    "workflow.add_edge(\"rhythm_analyzer\", \"style_adapter\")\n",
    "workflow.add_edge(\"style_adapter\", \"midi_converter\")\n",
    "workflow.add_edge(\"midi_converter\", END)\n",
    "\n",
    "# Compile the graph\n",
    "app = workflow.compile()"
   ]
  },
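  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because every edge above leads from one node to the next, the compiled graph behaves like a fixed pipeline in which each node returns a partial update that is merged into the shared state. The plain-Python sketch below mimics that order of execution without LangGraph; the stub nodes and `run_linear` are hypothetical stand-ins that make no LLM calls.\n",
    "\n",
    "```python\n",
    "def melody_stub(state):\n",
    "    # Stand-in for melody_generator: reads the user input, returns one new key\n",
    "    return {'melody': 'melody for: ' + state['musician_input']}\n",
    "\n",
    "\n",
    "def harmony_stub(state):\n",
    "    # Stand-in for harmony_creator: depends on the melody produced earlier\n",
    "    return {'harmony': 'harmony under: ' + state['melody']}\n",
    "\n",
    "\n",
    "def run_linear(nodes, state):\n",
    "    # Call each node in order and merge its partial update into a copy of the state\n",
    "    state = dict(state)\n",
    "    for node in nodes:\n",
    "        state.update(node(state))\n",
    "    return state\n",
    "\n",
    "\n",
    "demo_result = run_linear([melody_stub, harmony_stub], {'musician_input': 'a waltz'})\n",
    "print(demo_result['harmony'])\n",
    "```\n",
    "\n",
    "Swapping the order of the two stubs raises a `KeyError`, which shows why the edges must follow the data dependencies between nodes."
   ]
  },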
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the Workflow\n",
    "\n",
    "Execute the AI Music Collaborator workflow to generate a musical composition."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Composition created\n",
      "MIDI file saved at: C:\\Users\\N7\\AppData\\Local\\Temp\\tmpu0n_jslr.mid\n"
     ]
    }
   ],
   "source": [
    "# Define input parameters\n",
    "inputs = {\n",
    "    \"musician_input\": \"Create a happy piano piece in C major\",\n",
    "    \"style\": \"Romantic era\"\n",
    "}\n",
    "\n",
    "# Invoke the workflow\n",
    "result = app.invoke(inputs)\n",
    "\n",
    "print(\"Composition created\")\n",
    "print(f\"MIDI file saved at: {result['midi_file']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MIDI Playback Function\n",
    "\n",
    "Define a function to play the generated MIDI file using pygame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "To create and play a melody, run the following in a new cell:\n",
      "play_midi(result['midi_file'])\n"
     ]
    }
   ],
   "source": [
    "def play_midi(midi_file_path):\n",
    "    \"\"\"Play the generated MIDI file.\"\"\"\n",
    "    pygame.mixer.init()\n",
    "    pygame.mixer.music.load(midi_file_path)\n",
    "    pygame.mixer.music.play()\n",
    "\n",
    "    # Wait for playback to finish\n",
    "    while pygame.mixer.music.get_busy():\n",
    "        pygame.time.Clock().tick(10)\n",
    "    \n",
    "    # Clean up\n",
    "    pygame.mixer.quit()\n",
    "\n",
    "\n",
    "print(\"To create and play a melody, run the following in a new cell:\")\n",
    "print(\"play_midi(result['midi_file'])\")"
   ]
  },
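  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each part built in `midi_converter` contains 8 quarter-note beats at 60 BPM, which corresponds to 8 seconds of music. The arithmetic generalizes as seconds = beats * 60 / BPM; the small helper below is a sketch of that conversion, and the name `playback_seconds` is hypothetical.\n",
    "\n",
    "```python\n",
    "def playback_seconds(beats, bpm):\n",
    "    # Each beat lasts 60 / bpm seconds\n",
    "    if bpm <= 0:\n",
    "        raise ValueError('bpm must be positive')\n",
    "    return beats * 60 / bpm\n",
    "\n",
    "\n",
    "print(playback_seconds(8, 60))   # 8 beats at 60 BPM\n",
    "print(playback_seconds(8, 120))  # the same beats at double the tempo\n",
    "```\n"
   ]
  },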
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Play the Generated Music\n",
    "\n",
    "Execute this cell to play the generated MIDI file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "play_midi(result[\"midi_file\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
